# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""SubGraphInducer class used in nn.tf.Dataset."""


class SubGraphInducer(object):
  """Base class describing how SubGraph/HeteroSubGraph objects are induced.

  The constructor only records the configuration below and exposes it
  through read-only properties; the actual inducing logic lives in
  `induce_func`, which this base class leaves unimplemented.

  Args:
    use_neg: True when negative subgraphs should be generated as well.
    edge_types: List of edge types for a heterogeneous subgraph. Required
      when the induced edge_index needs an edge type that the query itself
      does not fetch. For instance, fetching neighbor nodes through
      q.outV('u-i') does not add the 'u-i' type to the query dag, so it has
      to be listed here.
    use_edges: True when edge data should be used in the SubGraph.
    node_types: List of node types used in heterogeneous subgraphs. Specify
      it to pick your own node types rather than using every node type that
      appears in the GSL query.
    addl_types_and_shapes: Dict describing the additional data of BatchGraph
      produced by `induce_func`. Keys are the names of the additional data;
      each value is a list [types, shapes] holding the tf.dtype and
      tf.TensorShape instances that describe the tensor format of that data.
  """

  def __init__(self,
               use_neg=False,
               edge_types=None,
               use_edges=False,
               node_types=None,
               addl_types_and_shapes=None):
    # Type selections for heterogeneous subgraphs.
    self._node_types = node_types
    self._edge_types = edge_types
    # Boolean switches.
    self._use_edges = use_edges
    # NOTE: the backing attribute is spelled `_user_neg` (not `_use_neg`);
    # the spelling is kept unchanged so the stored attribute name stays the
    # same for any code that reads it directly.
    self._user_neg = use_neg
    # Format description of extra data emitted by `induce_func`.
    self._addl_types_and_shapes = addl_types_and_shapes

  @property
  def use_neg(self):
    """The `use_neg` value given to the constructor."""
    return self._user_neg

  @property
  def use_edges(self):
    """The `use_edges` value given to the constructor."""
    return self._use_edges

  @property
  def node_types(self):
    """The `node_types` value given to the constructor."""
    return self._node_types

  @property
  def edge_types(self):
    """The `edge_types` value given to the constructor."""
    return self._edge_types

  @property
  def addl_types_and_shapes(self):
    """The `addl_types_and_shapes` value given to the constructor."""
    return self._addl_types_and_shapes

  def induce_func(self, data_dict):
    """Induces `SubGraph`/`HeteroSubGraph` objects from query results.

    This base implementation does nothing but raise; the inducing logic
    has to be supplied by an override.

    Args:
      data_dict: GSL query results in numpy format.

    Returns:
      A 2-element tuple (positive subgraphs, negative subgraphs), where each
      element is a list of SubGraph/HeteroSubGraph. When there are no
      negative subgraphs, `use_neg` must be set to False and the second
      element of the tuple must be None.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()